/*
   Copyright The containerd Authors.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
*/

package main

import (
	"fmt"
	"strings"

	"google.golang.org/protobuf/compiler/protogen"
)

// generator emits Go code that targets ttrpc.Server and ttrpc.Client.
// In contrast to the original gogo-based plugin, it does not generate
// serializers for message types; protoc-gen-go is left to produce those.
type generator struct {
	// out is the file all generated code is written to.
	out *protogen.GeneratedFile

	// ident caches identifiers as they must be spelled inside out
	// (e.g. "context.Context", "ttrpc.Server"). These strings are resolved
	// once through out.QualifiedGoIdent.
	ident struct {
		context     string // context.Context
		server      string // ttrpc.Server
		client      string // ttrpc.Client
		method      string // ttrpc.Method
		stream      string // ttrpc.Stream
		serviceDesc string // ttrpc.ServiceDesc
		streamDesc  string // ttrpc.StreamDesc

		// Unqualified idents for the stream types. Their GoName is also the
		// name of an embedded field of that type.
		streamServerIdent protogen.GoIdent // ttrpc.StreamServer
		streamClientIdent protogen.GoIdent // ttrpc.ClientStream

		// Qualified spellings of the two idents above.
		streamServer string
		streamClient string
	}
}

// newGenerator returns a generator that writes into out.
//
// Every context/ttrpc identifier the generated code needs is resolved up
// front through out.QualifiedGoIdent. That call yields the spelling to use
// inside out and records the package import on out.
func newGenerator(out *protogen.GeneratedFile) *generator {
	const ttrpcPkg protogen.GoImportPath = "github.com/containerd/ttrpc"
	ttrpcIdent := func(name string) protogen.GoIdent {
		return protogen.GoIdent{GoImportPath: ttrpcPkg, GoName: name}
	}

	g := &generator{out: out}
	g.ident.context = out.QualifiedGoIdent(protogen.GoIdent{GoImportPath: "context", GoName: "Context"})
	g.ident.server = out.QualifiedGoIdent(ttrpcIdent("Server"))
	g.ident.client = out.QualifiedGoIdent(ttrpcIdent("Client"))
	g.ident.method = out.QualifiedGoIdent(ttrpcIdent("Method"))
	g.ident.stream = out.QualifiedGoIdent(ttrpcIdent("Stream"))
	g.ident.serviceDesc = out.QualifiedGoIdent(ttrpcIdent("ServiceDesc"))
	g.ident.streamDesc = out.QualifiedGoIdent(ttrpcIdent("StreamDesc"))

	// Keep the raw stream idents as well as their qualified spellings: the
	// GoName is what an embedded field of that type is called.
	g.ident.streamServerIdent = ttrpcIdent("StreamServer")
	g.ident.streamClientIdent = ttrpcIdent("ClientStream")
	g.ident.streamServer = out.QualifiedGoIdent(g.ident.streamServerIdent)
	g.ident.streamClient = out.QualifiedGoIdent(g.ident.streamClientIdent)

	return g
}

// generate creates input.GeneratedFilenamePrefix + "_ttrpc.pb.go" containing
// the ttrpc bindings for every service declared in input.
//
// Files that declare no services produce no output file at all.
// servicePrefix is prepended to each service's GoName before code generation.
// This rewrites input.Services in place, so every identifier derived from the
// service name carries the prefix. The returned error is currently always nil.
func generate(plugin *protogen.Plugin, input *protogen.File, servicePrefix string) error {
	// Message-only files: nothing to emit, and no empty file should be created.
	if len(input.Services) == 0 {
		return nil
	}

	filename := input.GeneratedFilenamePrefix + "_ttrpc.pb.go"
	out := plugin.NewGeneratedFile(filename, input.GoImportPath)
	out.P("// Code generated by protoc-gen-go-ttrpc. DO NOT EDIT.")
	out.P("// source: ", input.Desc.Path())
	out.P("package ", input.GoPackageName)

	g := newGenerator(out)
	for i := range input.Services {
		svc := input.Services[i]
		svc.GoName = servicePrefix + svc.GoName
		g.genService(svc)
	}
	return nil
}

// genService emits the complete ttrpc binding for one protobuf service, in
// this order:
//
//   - the <Service>Service interface implemented by servers;
//   - for each streaming RPC, the <Service>_<Method>Server interface and its
//     unexported wrapper struct;
//   - the Register<Service>Service function;
//   - the <Service>Client interface (only when the service has streaming RPCs;
//     otherwise clients satisfy <Service>Service), the unexported client
//     struct and its New<Service>Client constructor;
//   - one client method per RPC, plus client-side stream types for streaming
//     RPCs.
//
// All generated names are derived from service.GoName.
func (gen *generator) genService(service *protogen.Service) {
	serviceName := service.GoName + "Service"

	// Partition RPCs once: unary RPCs populate ServiceDesc.Methods, streaming
	// RPCs populate ServiceDesc.Streams.
	var methods, streams []*protogen.Method
	for _, method := range service.Methods {
		if isStreamingMethod(method) {
			streams = append(streams, method)
		} else {
			methods = append(methods, method)
		}
	}

	gen.genServiceInterface(service, serviceName)
	gen.genServerStreams(service, streams)
	gen.genRegister(service, serviceName, methods, streams)
	gen.genClient(service, serviceName, len(streams) > 0)
}

// isStreamingMethod reports whether the client side, the server side, or both
// sides of the RPC stream.
func isStreamingMethod(method *protogen.Method) bool {
	return method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer()
}

// streamStructName returns the unexported wrapper struct name generated for a
// streaming RPC. side is "Server" or "Client".
func streamStructName(service *protogen.Service, method *protogen.Method, side string) string {
	return strings.ToLower(service.GoName) + method.GoName + side
}

// genServiceInterface emits the <Service>Service interface.
//
// Unary RPCs take the request pointer and return (response, error).
// Streaming RPCs take a <Service>_<Method>Server stream. The request pointer
// is passed first when only the server streams. They return error for
// server-streaming RPCs, or (response, error) when only the client streams.
func (gen *generator) genServiceInterface(service *protogen.Service, serviceName string) {
	p := gen.out
	p.P("type ", serviceName, " interface{")
	for _, method := range service.Methods {
		var sendArgs, retArgs string
		if isStreamingMethod(method) {
			sendArgs = fmt.Sprintf("%s_%sServer", service.GoName, method.GoName)
			if !method.Desc.IsStreamingClient() {
				sendArgs = fmt.Sprintf("*%s, %s", p.QualifiedGoIdent(method.Input.GoIdent), sendArgs)
			}
			if method.Desc.IsStreamingServer() {
				retArgs = "error"
			} else {
				retArgs = fmt.Sprintf("(*%s, error)", p.QualifiedGoIdent(method.Output.GoIdent))
			}
		} else {
			sendArgs = fmt.Sprintf("*%s", p.QualifiedGoIdent(method.Input.GoIdent))
			retArgs = fmt.Sprintf("(*%s, error)", p.QualifiedGoIdent(method.Output.GoIdent))
		}
		p.P(method.GoName, "(", gen.ident.context, ", ", sendArgs, ") ", retArgs)
	}
	p.P("}")
	p.P()
}

// genServerStreams emits, for every streaming RPC, the server-side stream
// interface and an unexported struct implementing it. The struct embeds
// ttrpc.StreamServer and adds typed Send/Recv helpers.
func (gen *generator) genServerStreams(service *protogen.Service, streams []*protogen.Method) {
	p := gen.out
	// An embedded field is named after its type, whatever the import alias.
	field := gen.ident.streamServerIdent.GoName

	for _, method := range streams {
		structName := streamStructName(service, method, "Server")

		p.P("type ", service.GoName, "_", method.GoName, "Server interface {")
		if method.Desc.IsStreamingServer() {
			p.P("Send(*", method.Output.GoIdent, ") error")
		}
		if method.Desc.IsStreamingClient() {
			p.P("Recv() (*", method.Input.GoIdent, ", error)")
		}
		p.P(gen.ident.streamServer)
		p.P("}")
		p.P()

		p.P("type ", structName, " struct {")
		p.P(gen.ident.streamServer)
		p.P("}")
		p.P()

		if method.Desc.IsStreamingServer() {
			p.P("func (x *", structName, ") Send(m *", method.Output.GoIdent, ") error {")
			p.P("return x.", field, ".SendMsg(m)")
			p.P("}")
			p.P()
		}

		if method.Desc.IsStreamingClient() {
			p.P("func (x *", structName, ") Recv() (*", method.Input.GoIdent, ", error) {")
			p.P("m := new(", method.Input.GoIdent, ")")
			p.P("if err := x.", field, ".RecvMsg(m); err != nil {")
			p.P("return nil, err")
			p.P("}")
			p.P("return m, nil")
			p.P("}")
			p.P()
		}
	}
}

// genRegister emits Register<Service>Service, which registers the service
// under its full proto name with a ttrpc.ServiceDesc. Unary handlers go in
// Methods, streaming handlers in Streams.
func (gen *generator) genRegister(service *protogen.Service, serviceName string, methods, streams []*protogen.Method) {
	p := gen.out
	p.P("func Register", serviceName, "(srv *", gen.ident.server, ", svc ", serviceName, "){")
	p.P(`srv.RegisterService("`, service.Desc.FullName(), `", &`, gen.ident.serviceDesc, "{")

	if len(methods) > 0 {
		p.P(`Methods: map[string]`, gen.ident.method, "{")
		for _, method := range methods {
			// Key unary handlers by the proto method name, because that is the
			// name the generated client passes to ttrpc.Client.Call. Keying by
			// GoName made RPCs whose proto name is not already CamelCase
			// (e.g. "get_info" vs "GetInfo") unreachable.
			p.P(`"`, method.Desc.Name(), `": func(ctx `, gen.ident.context, ", unmarshal func(interface{}) error)(interface{}, error){")
			p.P("var req ", method.Input.GoIdent)
			p.P("if err := unmarshal(&req); err != nil {")
			p.P("return nil, err")
			p.P("}")
			p.P("return svc.", method.GoName, "(ctx, &req)")
			p.P("},")
		}
		p.P("},")
	}

	if len(streams) > 0 {
		p.P(`Streams: map[string]`, gen.ident.stream, "{")
		for _, method := range streams {
			// NOTE(review): streams are keyed by GoName here and in the
			// generated client's NewStream call. Both sides agree, so this is
			// left as-is to avoid changing the wire name of non-CamelCase
			// streaming RPCs.
			p.P(`"`, method.GoName, `": {`)
			p.P(`Handler: func(ctx `, gen.ident.context, ", stream ", gen.ident.streamServer, ") (interface{}, error) {")

			structName := streamStructName(service, method, "Server")
			// When only the server streams, the single request arrives as the
			// first message and is handed to the implementation directly.
			var sendArg string
			if !method.Desc.IsStreamingClient() {
				sendArg = "m, "
				p.P("m := new(", method.Input.GoIdent, ")")
				p.P("if err := stream.RecvMsg(m); err != nil {")
				p.P("return nil, err")
				p.P("}")
			}
			if method.Desc.IsStreamingServer() {
				p.P("return nil, svc.", method.GoName, "(ctx, ", sendArg, "&", structName, "{stream})")
			} else {
				p.P("return svc.", method.GoName, "(ctx, ", sendArg, "&", structName, "{stream})")
			}
			p.P("},")
			p.P("StreamingClient: ", method.Desc.IsStreamingClient(), ",")
			p.P("StreamingServer: ", method.Desc.IsStreamingServer(), ",")
			p.P("},")
		}
		p.P("},")
	}

	p.P("})")
	p.P("}")
	p.P()
}

// genClient emits the client-side API: the client interface (only when
// hasStreams), the unexported client struct, New<Service>Client and one
// method per RPC.
func (gen *generator) genClient(service *protogen.Service, serviceName string, hasStreams bool) {
	p := gen.out
	clientType := service.GoName + "Client"

	// For consistency with ttrpc 1.0 (no streaming), a service without
	// streams reuses the server interface as the client interface. Streaming
	// RPCs have different client and server signatures, so they get a
	// dedicated interface.
	clientInterface := serviceName
	if hasStreams {
		clientInterface = clientType
		p.P("type ", clientInterface, " interface{")
		for _, method := range service.Methods {
			if isStreamingMethod(method) {
				var sendArg string
				if !method.Desc.IsStreamingClient() {
					sendArg = fmt.Sprintf("*%s, ", p.QualifiedGoIdent(method.Input.GoIdent))
				}
				p.P(method.GoName,
					"(", gen.ident.context, ", ", sendArg,
					") (", service.GoName, "_", method.GoName, "Client, error)")
			} else {
				p.P(method.GoName,
					"(", gen.ident.context, ", ",
					"*", method.Input.GoIdent, ")",
					"(*", method.Output.GoIdent, ", error)")
			}
		}
		p.P("}")
		p.P()
	}

	clientStructType := strings.ToLower(service.GoName) + "Client"
	p.P("type ", clientStructType, " struct{")
	p.P("client *", gen.ident.client)
	p.P("}")
	p.P("func New", clientType, "(client *", gen.ident.client, ")", clientInterface, "{")
	p.P("return &", clientStructType, "{")
	p.P("client:client,")
	p.P("}")
	p.P("}")
	p.P()

	for _, method := range service.Methods {
		gen.genClientMethod(service, clientStructType, method)
	}
}

// genClientMethod emits the client method for one RPC. For a streaming RPC it
// also emits the <Service>_<Method>Client interface and its unexported
// wrapper struct embedding ttrpc.ClientStream.
func (gen *generator) genClientMethod(service *protogen.Service, clientStructType string, method *protogen.Method) {
	p := gen.out
	fullName := service.Desc.FullName()
	streaming := isStreamingMethod(method)

	// Client-streaming RPCs send requests through the stream, so the method
	// takes no request argument.
	var sendArg string
	if !method.Desc.IsStreamingClient() {
		sendArg = ", req *" + p.QualifiedGoIdent(method.Input.GoIdent)
	}

	intName := service.GoName + "_" + method.GoName + "Client"
	retArg := intName
	if !streaming {
		retArg = "*" + p.QualifiedGoIdent(method.Output.GoIdent)
	}

	p.P("func (c *", clientStructType, ") ", method.GoName,
		"(ctx ", gen.ident.context, sendArg, ") ",
		"(", retArg, ", error) {")

	if !streaming {
		// Unary RPC: invoked by proto method name, which is also the key
		// Register<Service>Service uses for unary handlers.
		p.P("var resp ", method.Output.GoIdent)
		p.P(`if err := c.client.Call(ctx, "`, fullName, `", "`, method.Desc.Name(), `", req, &resp); err != nil {`)
		p.P("return nil, err")
		p.P("}")
		p.P("return &resp, nil")
		p.P("}")
		p.P()
		return
	}

	// When only the server streams, the single request is sent while opening
	// the stream. Client-streaming RPCs send theirs later via Send.
	req := "req"
	if method.Desc.IsStreamingClient() {
		req = "nil"
	}
	p.P("stream, err := c.client.NewStream(ctx, &", gen.ident.streamDesc, "{")
	p.P("StreamingClient: ", method.Desc.IsStreamingClient(), ",")
	p.P("StreamingServer: ", method.Desc.IsStreamingServer(), ",")
	p.P("}, ", `"`+fullName+`", `, `"`+method.GoName+`", `, req, `)`)
	p.P("if err != nil {")
	p.P("return nil, err")
	p.P("}")

	structName := streamStructName(service, method, "Client")
	p.P("x := &", structName, "{stream}")
	p.P("return x, nil")
	p.P("}")
	p.P()

	// Client-side stream interface.
	p.P("type ", intName, " interface {")
	if method.Desc.IsStreamingClient() {
		p.P("Send(*", method.Input.GoIdent, ") error")
	}
	if method.Desc.IsStreamingServer() {
		p.P("Recv() (*", method.Output.GoIdent, ", error)")
	} else {
		p.P("CloseAndRecv() (*", method.Output.GoIdent, ", error)")
	}
	p.P(gen.ident.streamClient)
	p.P("}")
	p.P()

	// Wrapper struct; the embedded field is named after its type.
	field := gen.ident.streamClientIdent.GoName
	p.P("type ", structName, " struct {")
	p.P(gen.ident.streamClient)
	p.P("}")
	p.P()

	if method.Desc.IsStreamingClient() {
		p.P("func (x *", structName, ") Send(m *", method.Input.GoIdent, ") error {")
		p.P("return x.", field, ".SendMsg(m)")
		p.P("}")
		p.P()
	}

	if method.Desc.IsStreamingServer() {
		p.P("func (x *", structName, ") Recv() (*", method.Output.GoIdent, ", error) {")
		p.P("m := new(", method.Output.GoIdent, ")")
		p.P("if err := x.", field, ".RecvMsg(m); err != nil {")
		p.P("return nil, err")
		p.P("}")
		p.P("return m, nil")
		p.P("}")
		p.P()
		return
	}

	// Client-streaming only: finish sending, then read the single response.
	p.P("func (x *", structName, ") CloseAndRecv() (*", method.Output.GoIdent, ", error) {")
	p.P("if err := x.", field, ".CloseSend(); err != nil {")
	p.P("return nil, err")
	p.P("}")
	p.P("m := new(", method.Output.GoIdent, ")")
	p.P("if err := x.", field, ".RecvMsg(m); err != nil {")
	p.P("return nil, err")
	p.P("}")
	p.P("return m, nil")
	p.P("}")
	p.P()
}
